
# Generate data from a multivariate normal distribution
#source('generateData.R')
# ======================== Handle synthetic data  ========================
library('MASS') 
# Build an n x p design matrix whose rows are independent zero-mean Gaussian
# draws. The covariance V is block diagonal: inside a block, entries (k, j)
# have covariance r^|k - j|; entries in different blocks are uncorrelated.
n <- 1000
ps <- c(5, 2, 4)  # block sizes
p <- sum(ps)
r <- 0.9          # within-block decay rate
V <- matrix(0, nrow = p, ncol = p)
block_ends <- cumsum(ps)
for (b in seq_along(ps)) {
  # Global row/column indices covered by block b
  idx <- seq(block_ends[b] - ps[b] + 1, block_ends[b])
  offsets <- seq_len(ps[b])
  V[idx, idx] <- r^abs(outer(offsets, offsets, "-"))
}
X <- mvrnorm(n, mu = rep(0, p), Sigma = V)
dim(X)
V_X <- t(X) %*% X / n  # sample covariance (the mean is zero by construction)


# ============================== Read data for gradient information ==========================================
# Load the series from CSV and build the matrix V that is later passed to
# detection().
# TODO: decide on an appropriate definition of V (sample covariance vs.
# correlation); the correlation matrix is used for now.
data <- read.csv("states_qtr_031908.csv")
X <- data[, -1]  # drop the first column (presumably a date/label column - confirm)
dim(X)
n <- 207         # number of leading rows that are kept
# Guard: indexing past the last row would silently add all-NA rows.
stopifnot(nrow(X) >= n)
X <- as.matrix(X[1:n, ])
p <- ncol(X)
dim(X)
#V <- t(X) %*% X / n #sample cov
V <- as.matrix(cor(X))
names(data[, -1])
# (A stray `opt_blocks` expression used to sit here; it was evaluated before
# the variable is first assigned further down and failed in a fresh session.)




# Multi-start search: the block updates are randomised, so run detection()
# from several random initial partitions and keep the run with the smallest
# entropy score.
# NOTE: detection() is defined in the "Functions" section further down in this
# file; that section has to be evaluated before this loop is run.
K <- 3         # number of blocks
nInit <- 1000  # number of random restarts
Iter <- 100    # number of update steps per restart
# Start from +Inf so the first run is always recorded; a sentinel of 0 left
# opt_blocks / inv_mats undefined unless some run scored strictly below zero.
min_entrop <- Inf
opt_blocks <- NULL
inv_mats <- NULL
for (iter in seq_len(nInit)) {
  res <- detection(Iter, K, V)
  if (res$entrop < min_entrop) {
    min_entrop <- res$entrop
    opt_blocks <- res$blocks
    inv_mats <- res$inv_mats
  }
}
min_entrop
opt_blocks
#inv_mats
plot(res$adjust_his)  # switch history of the LAST restart, not of the best one
#plot(incre_his)
#install.packages(c("maps", "mapdata"))
#install.packages('ggmap')
library(ggplot2)
library(ggmap)
library(maps)
library(mapdata)
states <- map_data("state")
# Cluster label for every polygon row; 0 means "not assigned to any block".
# (The old code passed the character vector of region names as the `times`
# argument of rep(), which is an error.)
states$cluster <- rep(0, nrow(states))
# Region names without DC. setdiff() is safe even if DC is absent, whereas
# x[-which(...)] returns an empty vector when nothing matches.
set_states <- setdiff(unique(states$region), "district of columbia")
# NOTE(review): this assumes column i of the data corresponds to
# set_states[i] - confirm against the CSV column order.
for (k in seq_len(K)) {
  states$cluster[states$region %in% set_states[opt_blocks[[k]]]] <- k
}
# Cluster ids are labels, so map them to a discrete fill scale.
ggplot(data = states) +
  geom_polygon(aes(x = long, y = lat, fill = factor(cluster), group = group),
               color = "white") +
  coord_fixed(1.3) +
  guides(fill = "none")  # leave off the color legend (fill = FALSE is deprecated)



#=========================================================
#======== Functions =============================
#=========================================================

# Randomly partition the indices 1..p into K blocks.
#
# K: number of blocks (1 <= K <= p).
# p: number of indices to partition.
#
# Returns a list of K sorted index vectors. The first K - 1 blocks each hold
# floor((p + 1) / K) indices and the last block takes whatever is left.
init_blocks <- function(K, p) {
  stopifnot(length(K) == 1, length(p) == 1, K >= 1, p >= K)
  perm <- sample(p)
  block_size <- floor((p + 1) / K)
  blocks <- vector("list", K)
  # seq_len() yields an empty loop for K == 1; 1:(K - 1) would iterate over
  # c(1, 0) and fail on blocks[[0]].
  for (k in seq_len(K - 1)) {
    blocks[[k]] <- sort(perm[((k - 1) * block_size + 1):(k * block_size)])
  }
  blocks[[K]] <- sort(perm[((K - 1) * block_size + 1):p])
  blocks
}

# Invert the diagonal sub-matrix of V belonging to each block.
#
# V:      square matrix.
# blocks: list of index vectors into the rows/columns of V.
#
# Returns a list with one inverse matrix per block, in the same order as
# `blocks`. An empty `blocks` list gives an empty list.
init_inv <- function(V, blocks) {
  # drop = FALSE keeps single-index blocks as 1 x 1 matrices.
  lapply(blocks, function(idx) solve(V[idx, idx, drop = FALSE]))
}

# Euclidean (L2) norm of a numeric vector.
norm_vec <- function(x) {
  squares <- x^2
  sqrt(sum(squares))
}

# Compute (I - u v')^{-1} with the Sherman-Morrison identity
#   (I - u v')^{-1} = I + u v' / (1 - v' u).
#
# u, v: numeric vectors of equal length n.
#
# Returns the n x n inverse (visibly). If v' u equals 1 the matrix I - u v' is
# singular and the result contains non-finite entries.
Sherman_formula <- function(u, v) {
  u <- as.numeric(u)
  v <- as.numeric(v)
  stopifnot(length(u) == length(v))
  denom <- 1 - sum(u * v)
  diag(length(u)) + tcrossprod(u, v) / denom
}


# Refill a block that became empty during an update.
#
# One index j is drawn uniformly at random from all blocks that currently hold
# more than one element and is moved into block indA. (The draw is random; it
# does not search for the move that reduces the entropy most.)
#
# V:        square matrix the blocks index into.
# blocks:   list of index vectors; blocks[[indA]] is the emptied block.
# inv_mats: list of inverses of V[blocks[[k]], blocks[[k]]], ordered like the
#           indices inside each block.
# indA:     position of the emptied block in `blocks`.
#
# Returns list(blocks = , inv_mats = ) with both lists updated.
reboot <- function(V, blocks, inv_mats, indA) {
  # Candidates are the members of every block that stays non-empty after
  # giving up one element. The number of blocks is taken from `blocks` itself
  # rather than from a global K.
  canset <- unlist(blocks[lengths(blocks) > 1])
  if (length(canset) == 0) {
    stop("no block has more than one element to donate", call. = FALSE)
  }
  j <- canset[sample.int(length(canset), 1)]

  # Find the block that currently holds j.
  indj <- max(which(vapply(blocks, function(b) j %in% b, logical(1))))

  # Move j into block indA. The inverse is stored as a 1 x 1 matrix: a bare
  # scalar would be read by diag() as a matrix size, not as a matrix.
  blocks[[indA]] <- j
  inv_mats[[indA]] <- matrix(1 / as.numeric(V[j, j]), nrow = 1, ncol = 1)

  # Remove j from its old block and downdate that block's inverse:
  # with P = old_inv[-pos, -pos], q = old_inv[-pos, pos], b = V[remaining, j],
  # inverse(V[remaining, remaining]) = (I - q b')^{-1} P.
  pos <- which(blocks[[indj]] == j)
  remaining <- blocks[[indj]][-pos]
  old_inv <- inv_mats[[indj]]
  inv_mats[[indj]] <- Sherman_formula(old_inv[-pos, pos], V[remaining, j]) %*%
    old_inv[-pos, -pos, drop = FALSE]
  blocks[[indj]] <- remaining

  list(blocks = blocks, inv_mats = inv_mats)
}


# Try to move one index from community A into community B.
#
# V:          covariance (or correlation) matrix.
# A, B:       index vectors of the two communities.
# iV_A, iV_B: inverses of V[A, A] and V[B, B]; they are only passed through
#             unchanged when no move happens.
#
# The objective is -0.5 * sum over communities of trace(inverse(V[S, S])).
# For a set S and an index i outside S, with u = inverse(V[S, S]) V[S, i] and
# c = 1 / (V[i, i] - V[i, S] u), the block-inverse formula gives
#   trace(inverse(V[S+i, S+i])) = trace(inverse(V[S, S])) + c * (1 + ||u||^2),
# so the squared norm of u is needed for `incre` to equal the real change.
#
# If A has a single element it is the candidate; otherwise one element of A is
# drawn at random. The move is made when `incre` < 0.
#
# Returns list(A, B, iV_A, iV_B, switch, incre). When a singleton A is merged
# into B, A and iV_A are NA (not NULL, so that the list cell assigned by the
# caller does not disappear) and the caller has to refill A.
one_update_comm <- function(V, A, B, iV_A, iV_B) {
  # Increase in trace(inverse(V[S, S])) caused by adding index i to set S.
  trace_gain <- function(i, S) {
    u <- solve(V[S, S, drop = FALSE], V[S, i])
    schur <- as.numeric(V[i, i] - sum(V[i, S] * u))
    (1 + sum(u^2)) / schur
  }
  keep <- function(incre) {
    list(A = A, B = B, iV_A = iV_A, iV_B = iV_B, switch = FALSE, incre = incre)
  }

  if (length(A) == 1) {
    # Singleton A: decide whether to merge it into B.
    # NOTE(review): only the gain on the B side is counted here; the loss of
    # A's own term and the later refill of A are not - confirm this is intended.
    i <- A
    incre <- -trace_gain(i, B) / 2
    if (!(incre < 0)) {
      return(keep(incre))
    }
    # Do NOT sort here: i must stay in the last position of B_plus.
    B_plus <- union(B, i)
    # The inverse is recomputed from scratch; the rank-one update that used to
    # be sketched here was reported as unreliable.
    iV_B_plus <- solve(V[B_plus, B_plus, drop = FALSE])
    return(list(A = NA, B = B_plus, iV_A = NA, iV_B = iV_B_plus,
                switch = TRUE, incre = incre))
  }

  # A has several elements: pick one at random and test the move.
  i <- sample(A, 1)
  A_ <- A[A != i]
  incre <- -(trace_gain(i, B) - trace_gain(i, A_)) / 2
  if (!(incre < 0)) {
    return(keep(incre))
  }
  B_plus <- sort(union(B, i))
  # Both inverses are recomputed so that their row order matches the (sorted)
  # index order of the returned blocks; an incremental update that appends i
  # as the last row would not match the sorted B_plus.
  iV_A_ <- solve(V[A_, A_, drop = FALSE])
  iV_B_plus <- solve(V[B_plus, B_plus, drop = FALSE])
  list(A = A_, B = B_plus, iV_A = iV_A_, iV_B = iV_B_plus,
       switch = TRUE, incre = incre)
}

#test correctness 
# A <- c(1,6)
# B <- c(3,4,5)
# iV_A <- solve(V[A,A])
# iV_B <- solve(V[B,B])
# res <- one_update_comm(V, A, B, iV_A, iV_B)


# Run one randomised block-detection pass.
#
# Iter: number of update steps.
# K:    number of blocks (at least 2, since two distinct blocks are drawn in
#       every step).
# V:    square covariance (or correlation) matrix.
#
# Starting from a random partition of 1..ncol(V) into K blocks, each step
# draws two distinct blocks and calls one_update_comm() to try moving one
# index from the first into the second. If the first block is emptied by the
# move, reboot() refills it.
#
# Returns a list with
#   entrop:     -0.5 * sum over blocks of trace(inverse of the block of V),
#   blocks:     final list of index vectors,
#   inv_mats:   matching list of block inverses,
#   adjust_his: per-step flag, TRUE when a move was made,
#   incre_his:  per-step value of `incre` reported by one_update_comm().
detection <- function(Iter, K, V) {
  # Problem size comes from V itself rather than from a global `p`.
  p <- ncol(V)
  blocks <- init_blocks(K, p)
  inv_mats <- init_inv(V, blocks)

  # Histories are sized by Iter (they used to be fixed at length 100).
  adjust_his <- rep(NA, Iter)
  incre_his <- rep(NA_real_, Iter)
  for (iter in seq_len(Iter)) {
    # Randomly select two distinct blocks: source (ind1) and target (ind2).
    ind <- sample.int(K, 2)
    ind1 <- ind[1]
    ind2 <- ind[2]
    res <- one_update_comm(V, blocks[[ind1]], blocks[[ind2]],
                           inv_mats[[ind1]], inv_mats[[ind2]])

    blocks[[ind1]] <- res$A
    inv_mats[[ind1]] <- res$iV_A
    blocks[[ind2]] <- res$B
    inv_mats[[ind2]] <- res$iV_B
    adjust_his[iter] <- res$switch
    incre_his[iter] <- res$incre

    # one_update_comm() marks an emptied source block with NA.
    if (anyNA(res$A)) {
      res <- reboot(V, blocks, inv_mats, ind1)
      blocks <- res$blocks
      inv_mats <- res$inv_mats
    }
  }

  entrop <- 0
  for (k in seq_len(K)) {
    # as.matrix(): diag() applied to a bare scalar would build an identity
    # matrix of that size instead of returning the value.
    entrop <- entrop - 0.5 * sum(diag(as.matrix(inv_mats[[k]])))
  }
  list(entrop = entrop, blocks = blocks, inv_mats = inv_mats,
       adjust_his = adjust_his, incre_his = incre_his)
}


# Entropy score of a partition, computed from scratch:
#   -0.5 * sum over blocks of trace(inverse(V[block, block])).
#
# V:      square covariance (or correlation) matrix.
# blocks: list of index vectors; any number of blocks is accepted (this used
#         to be hard-coded to exactly three).
#
# Returns a single numeric value.
calculate_entropy <- function(V, blocks) {
  traces <- vapply(
    blocks,
    function(idx) sum(diag(solve(V[idx, idx, drop = FALSE]))),
    numeric(1)
  )
  -0.5 * sum(traces)
}





  